import os
import random

import numpy as np
from torch.utils.data import Dataset

import re

from kor_char_parser import decompose_str_as_one_hot


def get_one_hot(targets, nb_classes):
    """Return float32 one-hot rows for the given class indices.

    ``targets`` may be a scalar or any (nested) sequence / array of ints; it is
    flattened first, so the result always has shape ``(num_targets, nb_classes)``.
    """
    identity = np.eye(nb_classes, dtype=np.float32)
    flat_indices = np.array(targets).reshape(-1)
    # Fancy-indexing the identity matrix picks one unit row per index.
    return identity[flat_indices]


class AgDataset(Dataset):
    """Dataset pairing five per-sample feature columns with a float label.

    Entries 0-4 of each record in ``data_info`` are treated as the "sum", "max",
    "mean", "min" and "nature" columns. Each column is flattened to a
    space-separated string with ``switch_list_2_str``. The first four are then
    encoded and padded with ``preprocess``.
    """

    def __init__(self, data_info: list, data_label: list, max_length: int):
        """
        initializer

        :param data_info: per-sample records; entries 0-4 of each record are
            sequences that get joined into strings
        :param data_label: per-sample labels, each convertible with ``float()``
        :param max_length: maximum (padded) encoded string length
        """
        self.data_info = data_info

        # Flatten each feature column of every record into a single string.
        self.data_info_sum = [switch_list_2_str(info[0]) for info in data_info]
        self.data_info_max = [switch_list_2_str(info[1]) for info in data_info]
        self.data_info_mean = [switch_list_2_str(info[2]) for info in data_info]
        self.data_info_min = [switch_list_2_str(info[3]) for info in data_info]
        self.data_info_nature = [switch_list_2_str(info[4]) for info in data_info]

        # Encode + pad the four statistic columns.
        self.data_info_sum = preprocess(self.data_info_sum, max_length)
        self.data_info_max = preprocess(self.data_info_max, max_length)
        self.data_info_mean = preprocess(self.data_info_mean, max_length)
        self.data_info_min = preprocess(self.data_info_min, max_length)
        # NOTE(review): the "nature" column is intentionally left as raw strings
        # (its preprocess call was disabled in the original) - confirm intended.

        # One list per sample: four preprocessed columns + the raw nature string.
        self.data = [
            list(columns)
            for columns in zip(
                self.data_info_sum,
                self.data_info_max,
                self.data_info_mean,
                self.data_info_min,
                self.data_info_nature,
            )
        ]

        # Builtin float: `np.float` was only an alias for it and was removed
        # in NumPy 1.24.
        self.data_label = [float(x) for x in data_label]

    def __len__(self):
        """Return the number of samples (one per record of ``data_info``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample_columns, label)`` for sample ``idx``."""
        return self.data[idx], self.data_label[idx]


def switch_list_2_str(data: list):
    """Join the elements of *data* into one space-separated string.

    Each element is converted with ``str()``, so ``[1, 2.5, "a"]`` becomes
    ``"1 2.5 a"``. An empty sequence yields ``""``.
    """
    return " ".join(map(str, data))


def preprocess(data: list, max_length: int, *, encoder=None):
    """Encode strings and pad/truncate each encoding to ``max_length``.

    :param data: strings to encode
    :param max_length: width of every output row; longer encodings are truncated
    :param encoder: callable mapping one string to a sequence of ints; defaults
        to ``decompose_str_as_one_hot(s, warning=False)``
    :return: list of ``(ids, positions)`` tuples, one per input string. ``ids``
        is an int32 row of length ``max_length`` holding the (truncated)
        encoding followed by zeros. ``positions`` is an int32 row
        ``[1, 2, ..., L, 0, ..., 0]`` with ``L = min(len(encoding), max_length)``.
        Empty ``data`` returns ``[]``.
    """
    if encoder is None:
        def encoder(datum):
            return decompose_str_as_one_hot(datum, warning=False)

    vectorized_data = [encoder(datum) for datum in data]
    # default=0 keeps this diagnostic from raising on empty input.
    print("longest length: ", max((len(seq) for seq in vectorized_data), default=0))

    num_rows = len(vectorized_data)
    zero_padding = np.zeros((num_rows, max_length), dtype=np.int32)
    lens1 = np.zeros((num_rows, max_length), dtype=np.int32)
    for idx, seq in enumerate(vectorized_data):
        length = min(len(seq), max_length)
        # Covers both the truncation and the short-sequence case; the rest of
        # each row stays zero from np.zeros.
        zero_padding[idx, :length] = np.asarray(seq)[:length]
        lens1[idx, :length] = np.arange(1, length + 1)
    return list(zip(zero_padding, lens1))
